package edu.berkeley.nlp.math;

import edu.berkeley.nlp.mapper.AsynchronousMapper;
import edu.berkeley.nlp.mapper.SimpleMapper;
import edu.berkeley.nlp.util.Pair;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

/**
 * User: aria42
 * Date: Mar 10, 2009
 */
public class CachingObjectiveDifferentiableFunction<I> extends CachingDifferentiableFunction {

  private List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns;
  private Regularizer regularizer;
  private Collection<I> items;

  public CachingObjectiveDifferentiableFunction(Collection<I> items,
                                                List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns,
                                                Regularizer regularizer)
  {
    this.itemFns = itemFns;
    this.regularizer = regularizer;
    this.items = items;
  }

    public CachingObjectiveDifferentiableFunction(Collection<I> items,
                                                  ObjectiveItemDifferentiableFunction<I> itemFn,
                                                  Regularizer regularizer)
  {
    this(items, Collections.singletonList(itemFn),regularizer);
  }

  private class Mapper implements SimpleMapper<I> {
    ObjectiveItemDifferentiableFunction<I> itemFn;
    double objVal ;
    double[] localGrad ;
    Mapper(ObjectiveItemDifferentiableFunction<I> itemFn) {
      this.itemFn = itemFn;
      this.objVal = 0.0;
      this.localGrad = new double[itemFn.dimension()];
    }
    public void map(I elem) {
      objVal += itemFn.update(elem,localGrad);  
    }
  }

  private List<Mapper> getMappers() {
    List<Mapper> mappers = new ArrayList<Mapper>();
    for (ObjectiveItemDifferentiableFunction<I> itemFn : itemFns) {
      mappers.add(new Mapper(itemFn));
    }
    return mappers;
  }

  protected Pair<Double, double[]> calculate(double[] x) {
    for (ObjectiveItemDifferentiableFunction<I> itemFn : itemFns) {
      itemFn.setWeights(x);
    }
    List<Mapper> mappers = getMappers();
    AsynchronousMapper.doMapping(items,mappers);
    double objVal = 0.0;
    double[] grad = new double[dimension()];
    for (Mapper mapper : mappers) {
      objVal += mapper.objVal;
      DoubleArrays.addInPlace(grad,mapper.localGrad);
    }
    if (regularizer != null) {
      objVal += regularizer.update(x,grad,1.0);
    }
    return Pair.newPair(objVal,grad);
  }

  public int dimension() {
    return itemFns.get(0).dimension();
  }
}
